# Copyright 2017 Google Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions specifically for NMT."""
from __future__ import print_function

import codecs
import time
import numpy as np
import tensorflow as tf
import logging as log
from ..utils import evaluation_utils
from ..utils import misc_utils as utils


__all__ = ["decode_and_evaluate", "get_translation"]


def decode_and_evaluate(
    run,
    iterations,
    name,
    model,
    sess,
    trans_file,
    ref_file,
    metrics,
    subword_option,
    beam_width,
    tgt_eos,
    num_translations_per_input=1,
    decode=True,
    infer_mode="greedy",
):
    """Decode a test set and compute a score according to the evaluation task.

    Args:
      run: "accuracy" to write translations to trans_file, or "performance"
        to report latency/throughput statistics instead.
      iterations: number of decoding passes over the inference dataset.
      name: label printed next to each metric score.
      model: inference model exposing a ``decode(sess)`` method.
      sess: TensorFlow session used to run the model.
      trans_file: path of the output translation file.
      ref_file: path of the reference file; evaluation runs only when this is
        truthy and trans_file exists.
      metrics: iterable of metric names accepted by evaluation_utils.evaluate.
      subword_option: "bpe", "spm", or falsy — controls detokenization.
      beam_width: decoder beam width; caps num_translations_per_input.
      tgt_eos: target end-of-sentence token (text); output is cut at it.
      num_translations_per_input: hypotheses emitted per source sentence.
      decode: if False, skip decoding and go straight to evaluation.
      infer_mode: "greedy" or "beam_search".

    Returns:
      Dict mapping each metric name to its score; empty if no evaluation ran.
    """
    # Decode
    if decode:
        utils.print_out("  decoding to output %s" % trans_file)

        num_sentences = 0
        with codecs.getwriter("utf-8")(
            tf.gfile.GFile(trans_file, mode="wb")
        ) as trans_f:
            trans_f.write("")  # Write empty string to ensure file is created.

            # Greedy decoding yields one hypothesis; beam search can never
            # yield more hypotheses than the beam width.
            if infer_mode == "greedy":
                num_translations_per_input = 1
            elif infer_mode == "beam_search":
                num_translations_per_input = min(
                    num_translations_per_input, beam_width)

            print(
                "  infer_mode %s, beam_width %g, num translations per input %d. "
                % (infer_mode, beam_width, num_translations_per_input)
            )
            print("  total iterations count %d." % iterations)

            # prediction_times holds the model-forward time per batch only;
            # overall_time additionally covers data pre- and post-processing.
            prediction_times = []
            overall_start = time.time()

            for _ in range(iterations):
                # Drain the dataset once; OutOfRangeError marks exhaustion.
                while True:
                    try:
                        start = time.time()
                        nmt_outputs, _ = model.decode(sess)
                        prediction_times.append(time.time() - start)
                        if infer_mode != "beam_search":
                            # Add a singleton beam axis so indexing below is
                            # uniform across modes: [beam, batch, time].
                            nmt_outputs = np.expand_dims(nmt_outputs, 0)

                        batch_size = nmt_outputs.shape[1]

                        num_sentences += batch_size
                        for sent_id in range(batch_size):
                            for beam_id in range(num_translations_per_input):
                                translation = get_translation(
                                    nmt_outputs[beam_id],
                                    sent_id,
                                    tgt_eos=tgt_eos,
                                    subword_option=subword_option,
                                )
                                # Only the accuracy run persists translations;
                                # the performance run measures timing only.
                                if run == "accuracy":
                                    trans_f.write(
                                        (translation + b"\n").decode("utf-8"))

                    except tf.errors.OutOfRangeError:
                        utils.print_time(
                            "  done, num sentences %d, num translations per input %d"
                            % (num_sentences, num_translations_per_input),
                            overall_start,
                        )
                        break

        overall_time = time.time() - overall_start
        if run == "performance":
            # Guard each ratio: an empty dataset or iterations == 0 would
            # otherwise raise ZeroDivisionError here.
            if prediction_times:
                print(
                    "\nAverage Prediction Latency: {:.5f} sec per batch.".format(
                        sum(prediction_times) / float(len(prediction_times))
                    )
                )
            if iterations:
                print(
                    "Overall Latency: {:.5f} sec for the entire test "
                    "dataset.".format(overall_time / float(iterations))
                )
            if overall_time:
                print(
                    "Overall Throughput : {:.3f} sentences per sec.".format(
                        num_sentences / float(overall_time)
                    )
                )

    # Evaluation
    evaluation_scores = {}
    if ref_file and tf.gfile.Exists(trans_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(
                ref_file, trans_file, metric, subword_option=subword_option
            )
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores


def get_translation(nmt_outputs, sent_id, tgt_eos, subword_option):
    """Given batch decoding outputs, select a sentence and turn to text."""
    # EOS comparison happens in byte space, matching the decoder's output.
    eos_bytes = tgt_eos.encode("utf-8") if tgt_eos else tgt_eos

    # Pull the requested sentence out of the batch.
    token_ids = nmt_outputs[sent_id, :].tolist()

    # Truncate at the first EOS symbol when one is present.
    if eos_bytes and eos_bytes in token_ids:
        token_ids = token_ids[: token_ids.index(eos_bytes)]

    # Detokenize according to the subword scheme used at training time.
    if subword_option == "bpe":
        return utils.format_bpe_text(token_ids)
    if subword_option == "spm":
        return utils.format_spm_text(token_ids)
    return utils.format_text(token_ids)
